import torch
import torch.nn.functional as F

from .discriminator import Discriminator
from .gnn_encoder import AvgReadout
import numpy as np


# %%
class DGITask(torch.nn.Module):
    """Deep Graph Infomax style self-supervised loss over several datasets.

    For every dataset in ``data.datasets`` a fixed label vector is prepared
    (ones followed by zeros, one entry per node in each half). ``get_loss``
    contrasts the supplied embeddings against embeddings of the same graph
    whose feature rows have been randomly shuffled.
    """

    def __init__(self, data, encoder, embedding_size, device):
        """Store the collaborators and precompute the per-dataset labels.

        Args:
            data: Object exposing ``datasets``; each entry provides ``name``
                and ``data`` (with ``num_nodes``, ``x`` and ``edge_index``).
            encoder: Callable invoked as ``encoder(features, edge_index)``.
            embedding_size: Size argument forwarded to ``Discriminator``.
            device: Device that holds the discriminator and the labels.
        """
        super().__init__()
        self.name = 'deep graph infomax'
        self.data = data
        self.dataset_names = [entry.name for entry in self.data.datasets]
        self.device = device

        self.encoder = encoder
        self.read = AvgReadout()
        discriminator = Discriminator(embedding_size)
        self.predictor = discriminator.to(self.device)
        self.pseudo_labels = self.create_pseudo_labels()

    def create_pseudo_labels(self):
        """Return one label tensor per dataset, in ``data.datasets`` order.

        Each tensor has ``2 * num_nodes`` entries: ``num_nodes`` ones followed
        by ``num_nodes`` zeros, placed on ``self.device``.
        """

        def labels_for(node_count):
            positives = torch.ones(node_count)
            negatives = torch.zeros(node_count)
            return torch.cat((positives, negatives)).to(self.device)

        return [labels_for(entry.data.num_nodes) for entry in self.data.datasets]

    def get_loss(self, embeddings, dataset_name):
        """Compute the binary cross-entropy contrastive loss for one dataset.

        Switches the module to training mode as a side effect.

        Args:
            embeddings: Embeddings to treat as the positive samples; assumed
                to hold one row per node of the named dataset so the logits
                line up with the stored pseudo labels -- TODO confirm.
            dataset_name: Name of a dataset in ``data.datasets``.

        Returns:
            The value of ``binary_cross_entropy_with_logits`` between the
            predictor output and the dataset's pseudo labels.

        Raises:
            ValueError: If ``dataset_name`` is not a known dataset name.
        """
        position = self.dataset_names.index(dataset_name)
        graph = self.data.datasets[position].data
        node_count = graph.num_nodes

        self.train()
        # Corruption: keep the graph structure, shuffle which node owns which
        # feature row.
        shuffled_rows = np.random.permutation(node_count)
        negative_embeddings = self.encoder(graph.x[shuffled_rows, :], graph.edge_index)

        summary = torch.sigmoid(self.read(embeddings))
        scores = self.predictor(summary, embeddings, negative_embeddings)
        return F.binary_cross_entropy_with_logits(scores, self.pseudo_labels[position])


class DGISample(torch.nn.Module):
    """Deep Graph Infomax style contrastive loss scored on a random node sample.

    Each loss evaluation encodes the original feature matrix and a
    row-shuffled copy of it, keeps the embeddings of ``sample_size`` randomly
    drawn nodes, and passes both sets to the discriminator together with a
    summary vector read out from the sampled original embeddings.
    """

    # Default upper bound on the number of nodes scored per loss evaluation.
    DEFAULT_SAMPLE_SIZE = 2000

    def __init__(self, data, processed_data, encoder, nhid, device, **kwargs):
        """Set up the readout, discriminator, loss and pseudo labels.

        Args:
            data: Object whose ``adj.shape[0]`` gives the number of nodes.
            processed_data: Object exposing ``features`` and ``adj_norm``,
                which ``make_loss`` feeds to the encoder.
            encoder: Callable invoked as ``encoder(features, adj)``.
            nhid: Size argument forwarded to ``Discriminator``.
            device: Device that holds the pseudo labels and the discriminator.
            **kwargs: May contain ``sample_size`` to override
                ``DEFAULT_SAMPLE_SIZE``. Any other key (for example ``args``)
                is accepted and ignored.
        """
        super().__init__()
        self.gcn = encoder
        self.data = data
        self.processed_data = processed_data
        self.device = device

        self.read = AvgReadout()
        self.sigm = torch.nn.Sigmoid()
        # Keep the discriminator on the same device as the pseudo labels it is
        # trained against.
        self.disc = Discriminator(nhid).to(self.device)
        self.b_xent = torch.nn.BCEWithLogitsLoss()

        self.num_nodes = data.adj.shape[0]
        # Sampling is done without replacement, so the sample can never be
        # larger than the graph; cap it before the labels are sized from it.
        requested = kwargs.get('sample_size', self.DEFAULT_SAMPLE_SIZE)
        self.sample_size = self._capped_sample_size(requested, self.num_nodes)
        self.pseudo_labels = self.get_label()

    @staticmethod
    def _capped_sample_size(requested, num_nodes):
        """Return ``requested`` limited to at most ``num_nodes``."""
        return min(int(requested), int(num_nodes))

    def get_label(self):
        """Return ``sample_size`` ones followed by ``sample_size`` zeros on ``device``."""
        positives = torch.ones(self.sample_size)
        negatives = torch.zeros(self.sample_size)
        return torch.cat((positives, negatives)).to(self.device)

    def make_loss(self, x):
        """Compute the contrastive loss on a fresh random node sample.

        Switches the module to training mode as a side effect.

        Args:
            x: Unused; kept so existing callers keep working.

        Returns:
            ``BCEWithLogitsLoss`` between the discriminator output and the
            stored pseudo labels.
        """
        features = self.processed_data.features
        adj = self.processed_data.adj_norm
        nb_nodes = features.shape[0]

        self.train()
        # Corruption: same graph structure, feature rows assigned to random nodes.
        idx = np.random.permutation(nb_nodes)
        shuf_fts = features[idx, :]

        logits = self.forward(features, shuf_fts, adj, None, None, None, None)
        return self.b_xent(logits, self.pseudo_labels)

    def forward(self, seq1, seq2, adj, sparse, msk, samp_bias1, samp_bias2):
        """Score sampled original embeddings against sampled corrupted ones.

        Args:
            seq1: Original features passed to the encoder.
            seq2: Corrupted features passed to the encoder.
            adj: Adjacency passed to the encoder for both feature sets.
            sparse: Unused.
            msk: Forwarded to the readout.
            samp_bias1: Forwarded to the discriminator.
            samp_bias2: Forwarded to the discriminator.

        Returns:
            The discriminator output; ``make_loss`` treats it as logits that
            line up with ``pseudo_labels``.
        """
        # NOTE(review): a new unseeded generator is created on every call, so
        # np.random.seed() does not make this sampling reproducible.
        idx = np.random.default_rng().choice(
            self.num_nodes, self.sample_size, replace=False)
        # TODO: remove sparse
        # The same node subset is used for both views.
        h_1 = self.gcn(seq1, adj)[idx]
        c = self.read(h_1, msk)
        c = self.sigm(c)
        h_2 = self.gcn(seq2, adj)[idx]
        return self.disc(c, h_1, h_2, samp_bias1, samp_bias2)

    def embed(self, seq, adj, sparse, msk):
        """Return detached encoder embeddings and their readout.

        Args:
            seq: Features passed to the encoder.
            adj: Adjacency passed to the encoder.
            sparse: Forwarded to the encoder as a third positional argument.
            msk: Forwarded to the readout.

        Returns:
            Tuple ``(embeddings, summary)``, both detached from the graph.
        """
        # NOTE(review): forward() calls the encoder without ``sparse``; confirm
        # the encoder still accepts a third positional argument.
        h_1 = self.gcn(seq, adj, sparse)
        c = self.read(h_1, msk)

        return h_1.detach(), c.detach()


